1
0
Fork 0

Read incoming start element explicit part of sending outgoing start element.

This commit is contained in:
Matt Goodall 2012-07-11 16:09:48 +01:00
parent 30441cacf4
commit 318a97947b
3 changed files with 37 additions and 21 deletions

View File

@ -85,12 +85,12 @@ func startClient(stream *Stream, jid JID) error {
}, },
} }
if err := stream.SendStart(&start); err != nil { if rstart, err := stream.SendStart(&start); err != nil {
return err
}
if _, err := stream.Next(&xml.Name{nsStream, "stream"}); err != nil {
return err return err
} else {
if rstart.Name != (xml.Name{nsStream, "stream"}) {
return fmt.Errorf("unexpected start element: %s", rstart.Name)
}
} }
return nil return nil

View File

@ -33,24 +33,25 @@ func startComponent(stream *Stream, jid JID) (string, error) {
}, },
} }
if err := stream.SendStart(&start); err != nil { var streamId string
return "", err
}
streamId := "" if rstart, err := stream.SendStart(&start); err != nil {
if e, err := stream.Next(&xml.Name{nsStream, "stream"}); err != nil {
return "", err return "", err
} else { } else {
if rstart.Name != (xml.Name{nsStream, "stream"}) {
return "", fmt.Errorf("unexpected start element: %s", rstart.Name)
}
// Find the stream id. // Find the stream id.
for _, attr := range e.Attr { for _, attr := range rstart.Attr {
if attr.Name.Local == "id" { if attr.Name.Local == "id" {
streamId = attr.Value streamId = attr.Value
break break
} }
} }
if streamId == "" { }
return "", errors.New("Missing stream id")
} if streamId == "" {
return "", errors.New("Missing stream id")
} }
return streamId, nil return streamId, nil

View File

@ -55,12 +55,19 @@ func (stream *Stream) UpgradeTLS(config *tls.Config) error {
} }
// Send the element's start tag. Typically used to open the stream's document. // Send the element's start tag. Typically used to open the stream's document.
func (stream *Stream) SendStart(start *xml.StartElement) error { func (stream *Stream) SendStart(start *xml.StartElement) (*xml.StartElement, error) {
// Write start of outgoing doc.
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if err := writeXMLStartElement(buf, start); err != nil { if err := writeXMLStartElement(buf, start); err != nil {
return err return nil, err
} }
return stream.send(buf.Bytes()) if err := stream.send(buf.Bytes()); err != nil {
return nil, err
}
// Read and return start of incoming doc.
return nextStartElement(stream.dec)
} }
// Send a stanza. Used to write a complete, top-level element. // Send a stanza. Used to write a complete, top-level element.
@ -85,8 +92,19 @@ func (stream *Stream) send(b []byte) error {
// Bad things are very likely to happen if a call to Next() is successful but // Bad things are very likely to happen if a call to Next() is successful but
// you don't actually decode or skip the element. // you don't actually decode or skip the element.
func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) { func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) {
start, err := nextStartElement(stream.dec)
if err != nil {
return nil, err
}
if match != nil && start.Name != *match {
return nil, fmt.Errorf("Expected %s, got %s", *match, start.Name)
}
return start, nil
}
func nextStartElement(dec *xml.Decoder) (*xml.StartElement, error) {
for { for {
t, err := stream.dec.Token() t, err := dec.Token()
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
@ -95,9 +113,6 @@ func (stream *Stream) Next(match *xml.Name) (*xml.StartElement, error) {
} }
switch e := t.(type) { switch e := t.(type) {
case xml.StartElement: case xml.StartElement:
if match != nil && e.Name != *match {
return nil, fmt.Errorf("Expected %s, got %s", *match, e.Name)
}
return &e, nil return &e, nil
case xml.EndElement: case xml.EndElement:
log.Printf("EOF due to %s\n", e.Name) log.Printf("EOF due to %s\n", e.Name)