diff --git a/src/xmpp/client.go b/src/xmpp/client.go index fe37912..b08aca7 100644 --- a/src/xmpp/client.go +++ b/src/xmpp/client.go @@ -85,12 +85,12 @@ func startClient(stream *Stream, jid JID) error { }, } - if err := stream.SendStart(&start); err != nil { - return err - } - - if _, err := stream.Next(&xml.Name{nsStream, "stream"}); err != nil { + if rstart, err := stream.SendStart(&start); err != nil { return err + } else { + if rstart.Name != (xml.Name{nsStream, "stream"}) { + return fmt.Errorf("unexpected start element: %s", rstart.Name) + } } return nil diff --git a/src/xmpp/component.go b/src/xmpp/component.go index bf0d1c1..d6c2877 100644 --- a/src/xmpp/component.go +++ b/src/xmpp/component.go @@ -33,24 +33,25 @@ func startComponent(stream *Stream, jid JID) (string, error) { }, } - if err := stream.SendStart(&start); err != nil { - return "", err - } + var streamId string - streamId := "" - if e, err := stream.Next(&xml.Name{nsStream, "stream"}); err != nil { + if rstart, err := stream.SendStart(&start); err != nil { return "", err } else { + if rstart.Name != (xml.Name{nsStream, "stream"}) { + return "", fmt.Errorf("unexpected start element: %s", rstart.Name) + } // Find the stream id. - for _, attr := range e.Attr { + for _, attr := range rstart.Attr { if attr.Name.Local == "id" { streamId = attr.Value break } } - if streamId == "" { - return "", errors.New("Missing stream id") - } + } + + if streamId == "" { + return "", errors.New("Missing stream id") } return streamId, nil diff --git a/src/xmpp/stream.go b/src/xmpp/stream.go index 213dca5..d1b37b1 100644 --- a/src/xmpp/stream.go +++ b/src/xmpp/stream.go @@ -55,12 +55,19 @@ func (stream *Stream) UpgradeTLS(config *tls.Config) error { } // Send the element's start tag. Typically used to open the stream's document. -func (stream *Stream) SendStart(start *xml.StartElement) error { +func (stream *Stream) SendStart(start *xml.StartElement) (*xml.StartElement, error) { + + // Write start of outgoing doc. buf := new(bytes.Buffer) if err := writeXMLStartElement(buf, start); err != nil { - return err + return nil, err } - return stream.send(buf.Bytes()) + if err := stream.send(buf.Bytes()); err != nil { + return nil, err + } + + // Read and return start of incoming doc. + return nextStartElement(stream.dec) } // Send a stanza. Used to write a complete, top-level element. @@ -85,8 +92,19 @@ func (stream *Stream) send(b []byte) error { // Bad things are very likely to happen if a call to Next() is successful but // you don't actually decode or skip the element. func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) { + start, err := nextStartElement(stream.dec) + if err != nil { + return nil, err + } + if match != nil && start.Name != *match { + return nil, fmt.Errorf("Expected %s, got %s", *match, start.Name) + } + return start, nil +} + +func nextStartElement(dec *xml.Decoder) (*xml.StartElement, error) { for { - t, err := stream.dec.Token() + t, err := dec.Token() if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF @@ -95,9 +113,6 @@ func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) { } switch e := t.(type) { case xml.StartElement: - if match != nil && e.Name != *match { - return nil, fmt.Errorf("Expected %s, got %s", *match, e.Name) - } return &e, nil case xml.EndElement: log.Printf("EOF due to %s\n", e.Name)