Python Forum
What does this code mean? (Python, PyTorch)
#1
Please tell me what this code means:
        for in_channel, out_channel, stride, num_block in [
            [       64,          64,     1,       2],
            [       64,         128,     2,       2],
            [      128,         256,     2,       2],
            [      256,         512,     2,       2],
        ]:
            self.encode.append(
                nn.Sequential(
                    Basic( in_channel, out_channel,  stride=stride, ),
                    *[ Basic(out_channel, out_channel,  stride=1,      ) for i in range(1, num_block) ]
                )
            )
1) Why is "*" used? Does it mean "and"?
2) Why is "for i in range(1, num_block)" used?
3) There is a class inside a class (class Basic is used inside class Net). Does that mean the methods of class Basic will not be used while training the net?
4) I thought that "out_channel" should be equal to "in_channel" of the next layer. If class Basic is run twice, does that mean that
[       64,          64,     1,       2],
[       64,         128,     2,       2],
[      128,         256,     2,       2],
[      256,         512,     2,       2]
is not related to the second run of class Basic?
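The channel question in point 4 can be checked by replaying the loop's table in plain Python, without PyTorch. This is only a sketch that mirrors the loop; `describe_encoder` and `needs_shortcut` are hypothetical helper names, not part of the original code. For every Basic block it records the (in_channel, out_channel, stride) it would receive and whether Basic's shortcut condition (`in_channel != out_channel or stride != 1`) is true.

```python
# Same table as in the loop: (in_channel, out_channel, stride, num_block)
STAGES = [
    (64, 64, 1, 2),
    (64, 128, 2, 2),
    (128, 256, 2, 2),
    (256, 512, 2, 2),
]


def describe_encoder(stages):
    """Return one list of (in, out, stride) tuples per stage (hypothetical helper)."""
    plan = []
    for in_channel, out_channel, stride, num_block in stages:
        # First block of the stage: may change channel count and resolution.
        blocks = [(in_channel, out_channel, stride)]
        # Remaining num_block - 1 blocks: out_channel -> out_channel, stride 1.
        blocks += [(out_channel, out_channel, 1) for _ in range(1, num_block)]
        plan.append(blocks)
    return plan


def needs_shortcut(in_channel, out_channel, stride):
    # Same condition Basic.__init__ uses for self.is_shortcut.
    return in_channel != out_channel or stride != 1


if __name__ == "__main__":
    for i, blocks in enumerate(describe_encoder(STAGES), start=1):
        for cin, cout, s in blocks:
            print(f"encode[{i}]: Basic({cin}, {cout}, stride={s})"
                  f" shortcut={needs_shortcut(cin, cout, s)}")
```

Each stage's first block takes the previous stage's out_channel as its in_channel, and the second block of a stage receives (out_channel, out_channel), so the channel counts do chain from block to block (encode[0] already outputs 64 channels, matching the first row). Only the first block of stages 2 to 4 changes channels or stride, so only those blocks get the 1x1 projection shortcut.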

P.S. Full code:
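import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm2d  # assumption: the code uses the standard torch.nn BatchNorm2d


class Basic(nn.Module):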
    def __init__(self, in_channel, out_channel, stride=1, is_shortcut=False):
        super(Basic, self).__init__()
        self.conv1 = nn.Conv2d( in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1   = BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1,      padding=1, bias=False)
        self.bn2   = BatchNorm2d(out_channel)

        self.is_shortcut =  in_channel != out_channel or stride!=1
        if self.is_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0, stride=stride, bias=False),
                BatchNorm2d(out_channel)
            )

    def forward(self, x):
        if self.is_shortcut:
            shortcut = self.shortcut(x)
        else:
            shortcut = x

        x = self.bn1(self.conv1(x))
        x = F.relu(x,inplace=True)
        x = self.bn2(self.conv2(x)) +  shortcut
        x = F.relu(x,inplace=True)
        return x


class Net(nn.Module):

    def __init__(self, in_channel=3, num_class=4):
        super(Net, self).__init__()

        self.encode = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False),
                BatchNorm2d(64),
                nn.ReLU(inplace=True),
                #nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
            )
        ])

        for in_channel, out_channel, stride, num_block in [
            [       64,          64,     1,       2],
            [       64,         128,     2,       2],
            [      128,         256,     2,       2],
            [      256,         512,     2,       2],
        ]:
            self.encode.append(
                nn.Sequential(
                    Basic( in_channel, out_channel,  stride=stride, ),
                    *[ Basic(out_channel, out_channel,  stride=1,      ) for i in range(1, num_block) ]
                )
            )


        self.decode = nn.ModuleList([
            Decode( 512+256, 256),
            Decode( 256+128, 128),
            Decode( 128+ 64,  64),
            Decode(  64+ 64,  32),
            Decode(  32+  0,  16),
        ])

        self.logit = nn.Conv2d(16, num_class, kernel_size=1)



    def forward(self, x):
        batch_size,C,H,W = x.shape

        x = self.encode[0](x) ;  e0=x #; print('encode[0] :', x.shape)
        x = F.max_pool2d(x, kernel_size=3,stride=2,padding=1)

        x = self.encode[1](x) ;  e1=x #; print('encode[1] :', x.shape)
        x = self.encode[2](x) ;  e2=x #; print('encode[2] :', x.shape)
        x = self.encode[3](x) ;  e3=x #; print('encode[3] :', x.shape)
        x = self.encode[4](x) ;  e4=x #; print('encode[4] :', x.shape)

        #exit(0)
        x = self.decode[0](x,e3)      #; print('decode[0] :', x.shape)
        x = self.decode[1](x,e2)      #; print('decode[1] :', x.shape)
        x = self.decode[2](x,e1)      #; print('decode[2] :', x.shape)
        x = self.decode[3](x,e0)      #; print('decode[3] :', x.shape)
        x = self.decode[4](x)         #; print('decode[4] :', x.shape)

        #x = F.dropout(x, 0.5, training=self.training)
        logit = self.logit(x)
        return logit
    

# CHECKPOINT_FILE and input are defined elsewhere in the script (not shown in this post)
net = Net().cuda()
net.load_state_dict(torch.load(CHECKPOINT_FILE, map_location=lambda storage, loc: storage))


net.eval()
with torch.no_grad():
    logit = net(input)
    probability = torch.sigmoid(logit)

print('input: ', input.shape)
print('logit: ', logit.shape)
print('')
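The forward pass shrinks the spatial size several times. A rough sketch of that arithmetic, using the usual output-size formula for convolution and pooling layers, out = floor((size + 2*padding - kernel) / stride) + 1. The helper names are hypothetical, and the 256x256 input is only an assumed example; the thread does not say what size `input` has.

```python
def out_size(size, kernel, stride, padding):
    """Output length of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def basic_out(size, stride):
    # Basic: conv1 (3x3, given stride, padding 1), then conv2 (3x3, stride 1, padding 1).
    size = out_size(size, kernel=3, stride=stride, padding=1)
    return out_size(size, kernel=3, stride=1, padding=1)


def trace_encoder(size):
    """Spatial size of e0 ... e4 for a square input of side `size` (sketch)."""
    sizes = {}
    size = out_size(size, kernel=7, stride=2, padding=3)  # encode[0] conv
    sizes["e0"] = size
    size = out_size(size, kernel=3, stride=2, padding=1)  # F.max_pool2d
    for name, stride in [("e1", 1), ("e2", 2), ("e3", 2), ("e4", 2)]:
        size = basic_out(size, stride)  # first Basic of the stage
        size = basic_out(size, 1)       # second Basic (num_block = 2)
        sizes[name] = size
    return sizes


if __name__ == "__main__":
    channels = {"e0": 64, "e1": 64, "e2": 128, "e3": 256, "e4": 512}
    for name, side in trace_encoder(256).items():
        print(f"{name}: {channels[name]} channels, {side}x{side}")
```

With the assumed 256x256 input, e0 to e4 come out at 128, 64, 32, 16 and 8 pixels per side. The 1x1 shortcut convolution in Basic (stride 2, padding 0) also maps 64 to 32, which is why `self.bn2(self.conv2(x)) + shortcut` adds tensors of the same shape. The decoder's input channel counts (512+256, 256+128, 128+64, 64+64) equal the channels of x plus the skip feature passed alongside it, which suggests Decode concatenates the two; Decode's code is not shown in the thread, so that part is an inference.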
#2

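  1. The asterisk (*) in this case is "star notation" (argument unpacking). It's syntactic sugar to unpack the contents of the list it is attached to. So, instead of passing the list itself as one argument to nn.Sequential(), it unpacks the list and passes in each list item as a separate argument.
  2. The "for...in" belongs to the list comprehension; the full expression is "[...for...in...]". List comprehensions are a concise way to generate lists; here it builds the remaining num_block - 1 Basic blocks of each stage.
  3. The question is not clear. There are instances of Basic() inside of Net(), but they are separate classes and there's no inheritance, so their methods are called separately: each Basic instance's forward() runs whenever Net.forward() calls the nn.Sequential that contains it, including during training.
  4. The question is not clear. Each row of the list is used in one iteration of the loop and passed to Basic() as arguments multiple times: the first Basic() of a stage gets (in_channel, out_channel, stride), and the remaining ones get (out_channel, out_channel, stride=1).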
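Points 1 to 3 can be illustrated in plain Python. `Block` and `Chain` below are hypothetical stand-ins for `Basic` and `nn.Sequential` (they are not PyTorch code): `Chain` accepts any number of positional arguments, as `nn.Sequential` does, and calling it calls each inner block in order.

```python
class Block:
    """Hypothetical stand-in for Basic: it only remembers its arguments."""

    def __init__(self, in_channel, out_channel, stride=1):
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.stride = stride

    def __call__(self, log):
        # Stand-in for forward(): record that this block ran, then pass log on.
        log.append((self.in_channel, self.out_channel, self.stride))
        return log


class Chain:
    """Hypothetical stand-in for nn.Sequential: accepts blocks as *args."""

    def __init__(self, *blocks):
        self.blocks = blocks  # tuple of all positional arguments

    def __call__(self, log):
        # Calling the outer object calls each inner block's method in turn.
        for block in self.blocks:
            log = block(log)
        return log


if __name__ == "__main__":
    num_block = 3
    # List comprehension: one Block per value of range(1, num_block), i.e. 1 and 2.
    extra = [Block(128, 128, stride=1) for _ in range(1, num_block)]
    print(len(extra))  # 2

    with_star = Chain(Block(64, 128, stride=2), *extra)    # list unpacked
    without_star = Chain(Block(64, 128, stride=2), extra)  # list passed as one item

    print(len(with_star.blocks))     # 3
    print(len(without_star.blocks))  # 2
    print(with_star([]))  # [(64, 128, 2), (128, 128, 1), (128, 128, 1)]
```

The "list items" are the Block objects (Basic objects in the real code) created by the comprehension; with `*` each one becomes its own argument. Without the star, `Chain` receives the list itself as a single item, and calling it would fail because a list is not callable. `nn.Sequential` likewise expects each module as a separate argument and raises a TypeError if given a plain list. In the original table num_block is 2, so range(1, 2) yields one value and each stage gets exactly one extra Basic.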
#3
(Mar-11-2020, 08:20 PM) stullis wrote: (quoting post #2 in full)

Thank you very much for your answer! However, I have just started studying Python, so the answer is not clear to me.
1) Could you please give examples for your answer?
2) Could you please tell me what you mean by "list item" in relation to class Basic in this example?