// 5G-LENA nr-v3.3-120-gdac69c56: the 5G/NR module for the ns-3 simulator.
// File: nr-mac-scheduler-ofdma-ai.cc
// Copyright (c) 2024 Seoul National University (SNU)
// Copyright (c) 2024 Centre Tecnologic de Telecomunicacions de Catalunya (CTTC)
//
// SPDX-License-Identifier: GPL-2.0-only

#include "nr-mac-scheduler-ofdma-ai.h"

#include "ns3/boolean.h"
#include "ns3/log.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>

14namespace ns3
15{
16NS_LOG_COMPONENT_DEFINE("NrMacSchedulerOfdmaAi");
17NS_OBJECT_ENSURE_REGISTERED(NrMacSchedulerOfdmaAi);
18
19TypeId
21{
22 static TypeId tid =
23 TypeId("ns3::NrMacSchedulerOfdmaAi")
24 .SetParent<NrMacSchedulerOfdmaQos>()
25 .AddConstructor<NrMacSchedulerOfdmaAi>()
26 .AddAttribute("NotifyCbDl",
27 "The callback function to notify the AI model for the downlink",
28 CallbackValue(MakeNullCallback<NrMacSchedulerUeInfoAi::NotifyCb>()),
29 MakeCallbackAccessor(&NrMacSchedulerOfdmaAi::m_notifyCbDl),
30 MakeCallbackChecker())
31 .AddAttribute("NotifyCbUl",
32 "The callback function to notify the AI model for the uplink",
33 CallbackValue(MakeNullCallback<NrMacSchedulerUeInfoAi::NotifyCb>()),
34 MakeCallbackAccessor(&NrMacSchedulerOfdmaAi::m_notifyCbUl),
35 MakeCallbackChecker())
36 .AddAttribute("ActiveDlAi",
37 "The flag to activate the AI model for the downlink",
38 BooleanValue(false),
39 MakeBooleanAccessor(&NrMacSchedulerOfdmaAi::m_activeDlAi),
40 MakeBooleanChecker())
41 .AddAttribute("ActiveUlAi",
42 "The flag to activate the AI model for the uplink",
43 BooleanValue(false),
44 MakeBooleanAccessor(&NrMacSchedulerOfdmaAi::m_activeUlAi),
45 MakeBooleanChecker());
46 return tid;
47}
48
53
// Create the per-UE scheduler state for a newly configured UE: every UE is
// built as an NrMacSchedulerUeInfoAi from m_alpha plus the RNTI and beam id
// carried in the UE-config request parameters.
// NOTE(review): this listing dropped original lines 55-56, the qualified
// signature. Per the declaration list it is:
//   CreateUeRepresentation(const NrMacCschedSapProvider::CschedUeConfigReqParameters &params) const override
// Restore from upstream before building.
54std::shared_ptr<NrMacSchedulerUeInfo>
57{
58 NS_LOG_FUNCTION(this);
59 return std::make_shared<NrMacSchedulerUeInfoAi>(
60 m_alpha,
61 params.m_rnti,
62 params.m_beamId,
// NOTE(review): original line 63 (the final constructor argument and the
// closing ')') is missing. The declaration list mentions GetNumRbPerRbg(),
// which may be what was passed here -- confirm against upstream.
64}
65
// NOTE(review): only the first line (66) of this definition survived in this
// listing; original lines 67-75 are missing. Following this file's
// DL-before-UL ordering it is presumably GetUeCompareDlFn() const override,
// which returns the comparison function used to sort DL UEs -- restore the
// full definition from upstream before building.
66std::function<bool(const NrMacSchedulerNs3::UePtrAndBufferReq& lhs,
76
// NOTE(review): only the first line (77) of this definition survived in this
// listing; original lines 78-86 are missing. Following this file's
// DL-before-UL ordering it is presumably GetUeCompareUlFn() const override,
// which returns the comparison function used to sort UL UEs -- restore the
// full definition from upstream before building.
77std::function<bool(const NrMacSchedulerNs3::UePtrAndBufferReq& lhs,
87
88void
90{
91 NS_LOG_FUNCTION(this);
92 m_notifyCbDl = notifyCb;
93 m_activeDlAi = true;
94}
95
96void
98{
99 NS_LOG_FUNCTION(this);
100 m_notifyCbUl = notifyCb;
101 m_activeUlAi = true;
102}
103
104std::vector<NrMacSchedulerUeInfoAi::LcObservation>
106 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
107{
108 NS_LOG_FUNCTION(this);
109 std::vector<NrMacSchedulerUeInfoAi::LcObservation> observations;
110 for (const auto& ue : ueVector)
111 {
112 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
113 std::vector<NrMacSchedulerUeInfoAi::LcObservation> ueObservation =
114 uePtr->GetDlObservation();
115 observations.insert(observations.end(), ueObservation.begin(), ueObservation.end());
116 }
117 return observations;
118}
119
120std::vector<NrMacSchedulerUeInfoAi::LcObservation>
122 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
123{
124 NS_LOG_FUNCTION(this);
125 std::vector<NrMacSchedulerUeInfoAi::LcObservation> observations;
126 for (const auto& ue : ueVector)
127 {
128 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
129 std::vector<NrMacSchedulerUeInfoAi::LcObservation> ueObservation =
130 uePtr->GetUlObservation();
131 observations.insert(observations.end(), ueObservation.begin(), ueObservation.end());
132 }
133 return observations;
134}
135
136bool
138{
139 return false;
140}
141
142bool
144{
145 return false;
146}
147
148float
150 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
151{
152 NS_LOG_FUNCTION(this);
153 float reward = 0.0;
154 for (const auto& ue : ueVector)
155 {
156 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
157 reward += uePtr->GetDlReward();
158 }
159 return reward;
160}
161
162float
164 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
165{
166 NS_LOG_FUNCTION(this);
167 float reward = 0.0;
168 for (const auto& ue : ueVector)
169 {
170 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
171 reward += uePtr->GetUlReward();
172 }
173 return reward;
174}
175
176void
178 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
179{
180 NS_LOG_FUNCTION(this);
181 if (!m_notifyCbDl.IsNull())
182 {
183 std::string extraInfo = "";
186 this,
187 std::placeholders::_1,
188 ueVector);
189 m_notifyCbDl(GetUeObservationsDl(ueVector),
191 GetUeRewardsDl(ueVector),
192 extraInfo,
193 updateWeightsFn);
194 }
195}
196
197void
199 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
200{
201 NS_LOG_FUNCTION(this);
202 if (!m_notifyCbUl.IsNull())
203 {
204 std::string extraInfo = "";
207 this,
208 std::placeholders::_1,
209 ueVector);
210 m_notifyCbUl(GetUeObservationsUl(ueVector),
212 GetUeRewardsUl(ueVector),
213 extraInfo,
214 updateWeightsFn);
215 }
216}
217
218void
221 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
222{
223 NS_LOG_FUNCTION(this);
224 for (const auto& ue : ueVector)
225 {
226 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
227 NrMacSchedulerUeInfoAi::Weights weights = ueWeights.at(uePtr->m_rnti);
228 uePtr->UpdateDlWeights(weights);
229 }
230}
231
232void
235 const std::vector<NrMacSchedulerNs3::UePtrAndBufferReq>& ueVector) const
236{
237 NS_LOG_FUNCTION(this);
238 for (const auto& ue : ueVector)
239 {
240 auto uePtr = std::dynamic_pointer_cast<NrMacSchedulerUeInfoAi>(ue.first);
241 NrMacSchedulerUeInfoAi::Weights weights = ueWeights.at(uePtr->m_rnti);
242 uePtr->UpdateUlWeights(weights);
243 }
244}
245
246} // namespace ns3
bool m_activeUlAi
Flag for activating AI for uplink.
bool m_activeDlAi
Flag for activating AI for downlink.
uint64_t GetNumRbPerRbg() const
Private function that is used to get the number of resource blocks per resource block group and also ...
std::pair< UePtr, uint32_t > UePtrAndBufferReq
Pair between a pointer to NrMacSchedulerUeInfo and its buffer occupancy.
bool GetIsGameOverUl() const
Check if the uplink game is over.
float GetUeRewardsDl(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Get rewards for downlink.
void SetNotifyCbDl(NrMacSchedulerUeInfoAi::NotifyCb notifyCb)
Set the notify callback function for downlink.
std::vector< NrMacSchedulerUeInfoAi::LcObservation > GetUeObservationsUl(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Get UE observations for uplink.
void UpdateAllUeWeightsDl(const NrMacSchedulerUeInfoAi::UeWeightsMap &ueWeights, const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Update weights of all UEs for downlink.
std::function< bool(const NrMacSchedulerNs3::UePtrAndBufferReq &lhs, const NrMacSchedulerNs3::UePtrAndBufferReq &rhs)> GetUeCompareUlFn() const override
Return the comparison function to sort UL UEs according to the scheduler policy.
float GetUeRewardsUl(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Get rewards for uplink.
void SetNotifyCbUl(NrMacSchedulerUeInfoAi::NotifyCb notifyCb)
Set the notify callback function for uplink.
bool GetIsGameOverDl() const
Check if the downlink game is over.
std::function< bool(const NrMacSchedulerNs3::UePtrAndBufferReq &lhs, const NrMacSchedulerNs3::UePtrAndBufferReq &rhs)> GetUeCompareDlFn() const override
Return the comparison function to sort DL UEs according to the scheduler policy.
std::shared_ptr< NrMacSchedulerUeInfo > CreateUeRepresentation(const NrMacCschedSapProvider::CschedUeConfigReqParameters &params) const override
Create a UE representation of type NrMacSchedulerUeInfoAi.
static TypeId GetTypeId()
GetTypeId.
void UpdateAllUeWeightsUl(const NrMacSchedulerUeInfoAi::UeWeightsMap &ueWeights, const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Update weights of all UEs for uplink.
std::vector< NrMacSchedulerUeInfoAi::LcObservation > GetUeObservationsDl(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const
Get UE observations for downlink.
void CallNotifyUlFn(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const override
Call the notify callback function in the OpenGymEnv class in the ns3-gym module for uplink.
void CallNotifyDlFn(const std::vector< NrMacSchedulerNs3::UePtrAndBufferReq > &ueVector) const override
Call the notify callback function in the OpenGymEnv class in the ns3-gym module for downlink.
NrMacSchedulerOfdmaAi()
NrMacSchedulerOfdmaAi constructor.
Assign frequencies in QoS-based fashion.
static bool CompareUeWeightsUl(const NrMacSchedulerNs3::UePtrAndBufferReq &lue, const NrMacSchedulerNs3::UePtrAndBufferReq &rue)
comparison function object (i.e. an object that satisfies the requirements of Compare) which returns ...
static bool CompareUeWeightsDl(const NrMacSchedulerNs3::UePtrAndBufferReq &lue, const NrMacSchedulerNs3::UePtrAndBufferReq &rue)
comparison function object (i.e. an object that satisfies the requirements of Compare) which returns ...
std::unordered_map< uint8_t, double > Weights
A hash map for weights.
std::unordered_map< uint8_t, Weights > UeWeightsMap
A hash map for UE weights.
Callback< void, const std::vector< LcObservation > &, bool, float, const std::string &, const UpdateAllUeWeightsFn & > NotifyCb
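A callback type for notifying the AI model. Its parameters are the LC observations, a bool flag, a float reward, an extra-info string, and a function for updating the weights of all UEs.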
std::function< void(const UeWeightsMap &)> UpdateAllUeWeightsFn
A function type for updating the weights of all UEs.
static bool CompareUeWeightsUl(const NrMacSchedulerNs3::UePtrAndBufferReq &lue, const NrMacSchedulerNs3::UePtrAndBufferReq &rue)
comparison function object (i.e. an object that satisfies the requirements of Compare) which returns ...
static bool CompareUeWeightsDl(const NrMacSchedulerNs3::UePtrAndBufferReq &lue, const NrMacSchedulerNs3::UePtrAndBufferReq &rue)
comparison function object (i.e. an object that satisfies the requirements of Compare) which returns ...