
Notes on Python functions encountered in DSX RL

Editor: Python

1. zip() function
zip() takes iterables as arguments, packs the elements at the same position in each one into tuples, and returns an iterator over those tuples. Because the tuples are produced lazily instead of being stored in a list, this saves memory.

Use list() to turn the result into a list for output.

If the iterables have different lengths, the output stops at the shortest one. Combined with the * (unpacking) operator, zip can also "unzip" a sequence of tuples back into separate sequences.

>>> a = [1, 2, 3]
>>> b = [4, 5, 6]
>>> c = [4, 5, 6, 7, 8]
>>> zipped = zip(a, b)  # returns a zip object (an iterator), not a list
>>> zipped
<zip object at 0x103abc288>
>>> list(zipped)  # convert to a list with list()
[(1, 4), (2, 5), (3, 6)]
>>> list(zip(a, c))  # the result is as long as the shortest input
[(1, 4), (2, 5), (3, 6)]
>>> a1, a2 = zip(*zip(a, b))  # zip(*...) is the inverse of zip: it unzips the pairs into two tuples
>>> list(a1)
[1, 2, 3]
>>> list(a2)
[4, 5, 6]
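The zip(*...) unzipping trick is a common pattern in RL code, for example when splitting a sampled batch of transitions into separate fields in a DQN-style replay buffer. The sketch below assumes each transition is a (state, action, reward, next_state, done) tuple; the buffer contents are hypothetical.

```python
import random

# Hypothetical replay-buffer contents: (state, action, reward, next_state, done)
buffer = [
    ([0.0, 0.1], 0, 1.0, [0.1, 0.2], False),
    ([0.1, 0.2], 1, 1.0, [0.2, 0.3], False),
    ([0.2, 0.3], 0, 0.0, [0.3, 0.4], True),
]

random.seed(0)  # fixed seed so the run is repeatable
batch = random.sample(buffer, 2)  # list of 2 transition tuples

# zip(*batch) regroups the transitions field by field ("unzipping")
states, actions, rewards, next_states, dones = zip(*batch)
print(actions, rewards)
```

Each of states, actions, and so on is a tuple with one entry per sampled transition, ready to be turned into a tensor.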

https://www.runoob.com/python3/python3-func-zip.html

2. np.random.random() function
With no argument, it returns a single random float in the half-open interval [0, 1).
With a size argument, it returns an array of that shape filled with random floats in [0, 1).
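A short sketch of both call forms, plus the usual shift-and-scale trick for getting uniform values on a different interval (the interval [-1, 1) is only an illustrative choice):

```python
import numpy as np

np.random.seed(0)  # fixed seed so the run is repeatable

x = np.random.random()           # no argument: a single float in [0, 1)
arr = np.random.random((2, 3))   # size given: array of shape (2, 3), values in [0, 1)
print(type(x), arr.shape)

# Scaling by 2 and shifting by -1 maps [0, 1) onto [-1, 1)
scaled = 2 * np.random.random(5) - 1
print(scaled)
```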

3. numpy.random.randint() function

numpy.random.randint(low, high=None, size=None, dtype='l')

It returns random integers drawn from the half-open interval [low, high): low is included and high is excluded.
If high is omitted, the values are drawn from [0, low) instead.

>>> import numpy as np  # the outputs below are random and will vary between runs
>>> np.random.randint(2, size=10)
array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0])
>>> np.random.randint(1, size=10)  # [0, 1) contains only 0
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
>>> np.random.randint(5, size=(2, 4))
array([[4, 0, 2, 1],
       [3, 2, 2, 0]])
>>> np.random.randint(2, high=10, size=(2, 3))
array([[6, 8, 7],
       [2, 5, 2]])
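In RL code, np.random.random() and np.random.randint() often appear together in epsilon-greedy action selection. The sketch below is an assumption-based illustration, not the original code: epsilon_greedy and its q_values argument are hypothetical names.

```python
import numpy as np


def epsilon_greedy(q_values, epsilon):
    """Hypothetical helper: random action with probability epsilon, else greedy.

    q_values: 1-D array with one Q-value per action.
    """
    if np.random.random() < epsilon:
        # randint(n) draws from [0, n), i.e. any valid action index
        return int(np.random.randint(len(q_values)))
    return int(np.argmax(q_values))  # greedy: action with the largest Q-value


np.random.seed(0)  # fixed seed so the run is repeatable
q = np.array([0.1, 0.5, 0.2])
actions = [epsilon_greedy(q, epsilon=0.1) for _ in range(1000)]
print(np.bincount(actions, minlength=3))  # action 1 should dominate
```

With epsilon=0.1, the greedy action is chosen about 90% of the time plus its share of the random picks; the other actions still get explored occasionally.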

https://blog.csdn.net/u011851421/article/details/83544853

4. torch.gather() function

gather returns the values of a tensor at the positions given by its index argument (a tensor of indices).
b.gather(dim, index) and torch.gather(b, dim, index) are equivalent; the two key parameters are dim and index.

Low-dimensional (2-D) intuition:
dim=0 means indexing along rows: each value in index is a row number, so out[i][j] = b[index[i][j]][j].
dim=1 means indexing along columns: each value in index is a column number, so out[i][j] = b[i][index[i][j]].
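A small sketch of both cases. The Q-value tensor is hypothetical; the dim=1 case mirrors the common DQN-style step of picking Q(s, a) for the action actually taken in each state.

```python
import torch

# Hypothetical Q-values for a batch of 3 states and 2 actions
q_values = torch.tensor([[1.0, 2.0],
                         [3.0, 4.0],
                         [5.0, 6.0]])

# Action taken in each state; index must have the same number of dims as the input
actions = torch.tensor([[1], [0], [1]])  # shape (3, 1)

# dim=1: out[i][j] = q_values[i][actions[i][j]], one column picked per row
chosen = q_values.gather(1, actions)
print(chosen.tolist())  # [[2.0], [3.0], [6.0]]

# The function form gives the same result
assert torch.equal(chosen, torch.gather(q_values, 1, actions))

# dim=0: out[i][j] = q_values[rows[i][j]][j], one row picked per column
rows = torch.tensor([[2, 0]])  # shape (1, 2)
picked = q_values.gather(0, rows)
print(picked.tolist())  # [[5.0, 2.0]]
```

The output always has the shape of index, not of the input tensor.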

5. torch.distributions.Categorical

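import torch

# Categorical normalizes probs along the last dimension, so [0.9, 0.2]
# is treated as roughly [0.818, 0.182]
probs = torch.tensor([0.9, 0.2])
ac = torch.distributions.Categorical(probs)
print(ac)
for _ in range(5):
    print(ac.sample())  # tensor(0) or tensor(1)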

It creates a categorical distribution parameterized by probs. Samples are integers in {0, …, K-1}, where K is the length of probs: index i is drawn with probability probs[i] (after normalization), and sample() returns the sampled index.
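A sketch that checks this behavior empirically: the normalized probabilities, the sample frequencies over many draws, and log_prob, which policy-gradient methods such as REINFORCE use to build their loss. The sample count of 10000 is an arbitrary choice.

```python
import torch

torch.manual_seed(0)  # fixed seed so the run is repeatable

# probs need not sum to 1: Categorical normalizes them along the last dimension
dist = torch.distributions.Categorical(probs=torch.tensor([0.9, 0.2]))
print(dist.probs)  # normalized probabilities

# Draw many samples at once; each sample is an index in {0, 1}
samples = dist.sample((10000,))
freq = torch.bincount(samples, minlength=2).float() / samples.numel()
print(freq)  # empirical frequencies, close to dist.probs

# log_prob returns log P(action) for a sampled action
action = dist.sample()
print(action.item(), dist.log_prob(action).item())
```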

Now look at how the RL code selects an action from the policy network:

def take_action(self, state):  # sample an action from the policy's action distribution
    state = torch.tensor([state], dtype=torch.float).to(self.device)  # shape (1, 4): a batch of one state
    probs = self.policy_net(state)  # shape (1, 2): one probability per action
    action_dist = torch.distributions.Categorical(probs)
    action = action_dist.sample()  # shape (1,): the sampled action index
    return action.item()  # convert to a plain Python int
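To run take_action on its own, it needs a surrounding agent and policy network. The sketch below assumes a hypothetical two-layer PolicyNet ending in a softmax, a 4-dimensional state, 2 actions, and a hidden size of 128; none of these come from the original code. It wraps the state in np.array before calling torch.tensor, which avoids PyTorch's warning about building a tensor from a list of ndarrays.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class PolicyNet(nn.Module):
    """Hypothetical policy network: state -> action probabilities."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)  # each row sums to 1


class Agent:
    """Hypothetical minimal agent holding only what take_action needs."""

    def __init__(self, state_dim=4, hidden_dim=128, action_dim=2, device="cpu"):
        self.device = torch.device(device)
        self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(self.device)

    def take_action(self, state):
        # np.array([state]) adds the batch dimension: shape (1, state_dim)
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        probs = self.policy_net(state)  # shape (1, action_dim)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()  # shape (1,)
        return action.item()  # plain Python int


agent = Agent()
state = np.zeros(4, dtype=np.float32)  # stand-in for an environment observation
print(agent.take_action(state))
```

Because the action is sampled rather than taken as the argmax, the agent keeps exploring in proportion to the policy's probabilities.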


Copyright © 程式師世界 All Rights Reserved